import os
import json
import argparse
from azfuse import File

def _normalize_pope_answer(text):
    """Collapse a free-form answer string into the literal 'yes' or 'no'."""
    text = text.lower()

    # Only keep the first sentence
    text = text.split('.')[0]

    text = text.replace(',', '')
    # NOTE(review): splitting on single spaces only means a token glued to
    # other whitespace (e.g. "no\n") is not recognised and falls through to
    # the default below. Kept as-is so existing scores do not shift.
    words = text.split(' ')

    if 'i don' in text or 'unanswerable' in words:
        return 'no'
    if 'yes' in words:
        return 'yes'
    if 'no' in words or 'not' in words:
        return 'no'
    # No explicit yes/no token found: counted as "yes".
    return 'yes'


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    if not denominator:
        return 0.0
    return float(numerator) / float(denominator)


def eval_pope(answers, label_file):
    """Score yes/no answers against the labels stored in ``label_file``.

    Args:
        answers: list of dicts, each with a ``'text'`` entry holding the raw
            answer string. The ``'text'`` entry of every dict is overwritten
            in place with the normalised ``'yes'`` / ``'no'`` decision.
        label_file: path to a file with one JSON object per line, each having
            a ``'label'`` field. A label equal to ``'no'`` is the negative
            class; any other value is treated as positive.

    Returns:
        dict with the keys ``'f1'``, ``'acc'``, ``'precision'``, ``'recall'``
        and ``'yes_ratio'``. A metric whose denominator is zero (e.g. no
        answers, or no positive predictions) is reported as ``0.0`` instead
        of raising ``ZeroDivisionError``.

    The confusion matrix and all metrics are also printed to stdout.
    """
    # Close the label file deterministically and tolerate blank lines
    # (e.g. a trailing newline at the end of the file).
    with File.open(label_file, 'r') as f:
        label_list = [
            0 if json.loads(line)['label'] == 'no' else 1
            for line in f
            if line.strip()
        ]

    for answer in answers:
        answer['text'] = _normalize_pope_answer(answer['text'])
    pred_list = [0 if answer['text'] == 'no' else 1 for answer in answers]

    # Predictions and labels are paired purely by position; zip() below stops
    # at the shorter list, so make a length mismatch visible.
    if len(pred_list) != len(label_list):
        print('Warning: {} answers but {} labels; only the first {} pairs are scored'.format(
            len(pred_list), len(label_list), min(len(pred_list), len(label_list))
        ))

    pos = 1
    neg = 0
    yes_ratio = _safe_ratio(pred_list.count(pos), len(pred_list))

    TP, TN, FP, FN = 0, 0, 0, 0
    for pred, label in zip(pred_list, label_list):
        if pred == pos and label == pos:
            TP += 1
        elif pred == pos and label == neg:
            FP += 1
        elif pred == neg and label == neg:
            TN += 1
        elif pred == neg and label == pos:
            FN += 1

    print('TP\tFP\tTN\tFN\t')
    print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))

    precision = _safe_ratio(TP, TP + FP)
    recall = _safe_ratio(TP, TP + FN)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)
    acc = _safe_ratio(TP + TN, TP + TN + FP + FN)
    print('Accuracy: {}'.format(acc))
    print('Precision: {}'.format(precision))
    print('Recall: {}'.format(recall))
    print('F1 score: {}'.format(f1))
    print('Yes ratio: {}'.format(yes_ratio))
    print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
    return {
        'f1': f1,
        'acc': acc,
        'precision': precision,
        'recall': recall,
        'yes_ratio': yes_ratio
    }

if __name__ == "__main__":
    # Command-line driver: evaluates every ``coco_pope_<category>.json``
    # annotation file found in --annotation-dir, stores one
    # ``eval_result_<category>.json`` per category next to --result-file
    # (reusing an existing one unless --overwrite is given), and finally
    # writes the mean F1 to ``eval_result_overall.json``.
    parser = argparse.ArgumentParser()
    # All three paths are dereferenced unconditionally below, so a missing
    # one is reported by argparse instead of failing later on ``None``.
    parser.add_argument("--annotation-dir", type=str, required=True)
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--result-file", type=str, required=True)
    parser.add_argument("--overwrite", action='store_true')
    args = parser.parse_args()

    with File.open(args.question_file) as f:
        questions = {}
        for line in f:
            question = json.loads(line)
            questions[question['question_id']] = question
    with File.open(args.result_file) as f:
        answers = [json.loads(line) for line in f]

    annotation_prefix = 'coco_pope_'
    annotation_suffix = '.json'
    result_dir = os.path.dirname(args.result_file)

    f1_scores = []
    for file in File.list(args.annotation_dir):
        file_name = os.path.basename(file)
        if not file_name.startswith(annotation_prefix) or not file.endswith(annotation_suffix):
            # Not an annotation file: report it and move on.
            print(file)
            continue

        # The category is whatever sits between the prefix and the extension.
        category = file_name[len(annotation_prefix):-len(annotation_suffix)]
        output_file = os.path.join(result_dir, 'eval_result_{}.json'.format(category))

        if File.isfile(output_file) and not args.overwrite:
            # Reuse the scores stored by a previous run.
            with File.open(output_file, 'r') as f:
                scores = json.load(f)
            print('Category: {}'.format(category))
            print('F1: {:.3f}, Acc: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, Yes ratio: {:.3f}'.format(
                scores['f1'], scores['acc'], scores['precision'], scores['recall'], scores['yes_ratio']
            ))
        else:
            cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
            print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
            scores = eval_pope(cur_answers,  file)
            with File.open(output_file, 'w') as f:
                json.dump(scores, f, indent=4)

        f1_scores.append(scores['f1'])
        print("====================================")

    if not f1_scores:
        # Nothing was evaluated; exit with a clear message rather than
        # dividing by zero or writing a meaningless overall score.
        raise SystemExit('No {}*{} annotation files found in {}'.format(
            annotation_prefix, annotation_suffix, args.annotation_dir
        ))

    average_f1 = sum(f1_scores) / len(f1_scores)
    print('Average F1: {:.3f}'.format(average_f1))

    output_file = os.path.join(result_dir, 'eval_result_overall.json')
    with File.open(output_file, 'w') as f:
        scores = {'average_f1': average_f1}
        json.dump(scores, f, indent=4)

